import copy
import numpy as np

from tqdm import tqdm
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support, accuracy_score

import torch
from torch.cuda.amp import GradScaler, autocast

import warnings
# Install a catch-all "ignore" filter: every warning category is suppressed
# from this point on, including warnings raised by libraries used below.
warnings.filterwarnings("ignore")


def custom_loss_function(p, q, c, config, epsilon=1e-6, smoothing_factor=0.01):
    """Compute the element-wise clustering loss between ``p`` and ``q``.

    The regime is selected by ``config['cluster_num']``:

    * ``>= 2``: ``p * log(p / q)`` (ratio clamped from below by ``epsilon``)
      summed over dim 2; the reduced dim is kept as a trailing singleton.
    * otherwise: ``p`` is label-smoothed and the loss is
      ``-(p * log(f(q, c)) + (1 - p) * log(q ** (1 - c)))`` with
      ``f(q, c) = (1 - c**(1-c)) / (1 - c) * (q - c) + c**(1-c)``;
      both log arguments are clamped from below by ``epsilon``.

    Args:
        p: First distribution tensor. Must have at least 3 dims in the
            multi-cluster regime because the reduction runs over dim 2.
        q: Second distribution tensor, broadcastable against ``p``.
        c: Threshold tensor. NOTE: its ``.data`` is clamped IN PLACE to
            ``[0, 1 - epsilon]``, so the caller's tensor is modified.
        config: Mapping with ``'cluster_num'`` and, optionally,
            ``'smoothing_factor'``.
        epsilon: Lower clamp for log arguments and margin keeping ``c < 1``.
        smoothing_factor: Label-smoothing strength used when ``config`` has
            no ``'smoothing_factor'`` entry; a config entry takes precedence.

    Returns:
        The un-reduced loss tensor.
    """
    # Keep c strictly below 1 so that (1 - c) in the denominator never hits 0.
    c.data = torch.clamp(c.data, min=0.0, max=1 - epsilon)

    if config['cluster_num'] >= 2:
        kl_div = p * torch.log(torch.clamp(p / q, min=epsilon))
        return torch.sum(kl_div, dim=2).unsqueeze(2)

    # Only the single-cluster regime needs the smoothing factor; the config
    # value wins, the keyword argument is the fallback.
    smoothing_factor = config.get('smoothing_factor', smoothing_factor)

    p = p * (1 - smoothing_factor) + (1 - p) * smoothing_factor
    c_pow = c ** (1 - c)
    term1 = p * torch.log(torch.clamp((1 - c_pow) / (1 - c) * (q - c) + c_pow, min=epsilon))
    term2 = (1 - p) * torch.log(torch.clamp(q ** (1 - c), min=epsilon))

    # loss shape: [batch_size, sequence_size, number_of_clusters]
    return -(term1 + term2)

def mapping_loss_function(z, c, R, config):
    """Squared distance of every embedding to every cluster centre, plus the mapping loss.

    Args:
        z: Embeddings, [batch_size, sequence_size, features_size].
        c: Cluster centres, [number_of_clusters, features_size].
        R: Per-cluster radii; only read when ``config['objective']`` is
            ``'soft-boundary'``.
        config: Mapping with ``'objective'`` and, for the soft-boundary
            objective, ``'nu'``.

    Returns:
        ``(dist, scores, mapping_loss)``, each
        [batch_size, sequence_size, number_of_clusters]. For any objective
        other than ``'soft-boundary'`` all three are the same tensor.
    """
    # [B, S, 1, F] - [1, 1, K, F] broadcasts to [B, S, K, F].
    offsets = z.unsqueeze(2) - c.unsqueeze(0).unsqueeze(0)
    dist = offsets.pow(2).sum(dim=-1)

    if config['objective'] != 'soft-boundary':
        return dist, dist, dist

    # [1, 1, K] squared radii, broadcast against dist.
    radius_sq = R.unsqueeze(0).unsqueeze(0) ** 2
    scores = dist - radius_sq
    hinge = torch.max(torch.zeros_like(scores), scores)
    mapping_loss = radius_sq + (1 / config['nu']) * hinge
    return dist, scores, mapping_loss

def get_radius(dist: torch.Tensor, nu: float):
    """Return, per cluster, the ``(1 - nu)`` quantile of the Euclidean distances.

    ``dist`` holds squared distances indexed as [batch, sequence, cluster];
    the square root is taken before the quantile. The result is a 1-D NumPy
    array with one entry per index of the last axis.
    """
    roots = np.sqrt(dist.detach().cpu().numpy())
    level = 1 - nu
    per_cluster = [np.quantile(roots[:, :, k], level) for k in range(roots.shape[2])]
    return np.array(per_cluster)

def training(config, data_loader, model, optimizer, optimizer2, epoch_num, R, scaler):
    """Run one mixed-precision training epoch.

    For every batch the mapping loss and the clustering loss are added,
    summed over the cluster axis and averaged into a scalar, which is
    back-propagated through ``scaler``. ``optimizer2`` is zeroed and stepped
    only when ``config['cluster_num'] == 1``.

    Returns:
        ``(model, threshold, avg_mapping_loss, avg_cluster_loss,
        avg_total_loss, all_dists)`` where ``threshold`` is the value the
        model returned for the last batch, the averages are running means
        over batches (mapping: batch mean, cluster: batch sum, total: the
        optimised scalar) and ``all_dists`` concatenates every batch's
        distances on the CPU along dim 0.
    """
    model.train()
    progress = tqdm(data_loader, total=len(data_loader), desc=f"Train Epoch {epoch_num}")

    running_mapping = 0
    running_cluster = 0
    running_total = 0
    dist_batches = []

    for step, (batch, _) in enumerate(progress):
        batch = batch.to(config['device'])
        single_cluster = config['cluster_num'] == 1

        optimizer.zero_grad()
        if single_cluster:
            optimizer2.zero_grad()

        with autocast():
            z, p, q, c, threshold = model(batch)
            # NOTE(review): q is passed as the first distribution and p as the second — confirm intended.
            cluster_loss = custom_loss_function(q, p, threshold, config)
            dist, scores, mapping_loss = mapping_loss_function(z, c, R, config)

            # per-element total -> sum over clusters [batch_size, sequence_size] -> scalar mean
            per_position = (mapping_loss + cluster_loss).sum(dim=-1)
            loss = per_position.mean()

            # keep the per-cluster distances of this batch
            dist_batches.append(dist.detach().cpu())

        # Both optimizers are stepped before the single scaler.update().
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        if single_cluster:
            scaler.step(optimizer2)
        scaler.update()

        seen = step + 1
        running_mapping = (running_mapping * step + mapping_loss.mean().item()) / seen
        running_cluster = (running_cluster * step + cluster_loss.sum().item()) / seen
        running_total = (running_total * step + loss.item()) / seen

        threshold_text = ', '.join([f'{t:.5f}' for t in threshold.tolist()])
        progress.set_postfix({'Thre': threshold_text})

    all_dists = torch.cat(dist_batches, dim=0)
    return model, threshold, running_mapping, running_cluster, running_total, all_dists

def validation(config, data_loader, model, R):
    """Evaluate the model on ``data_loader`` without gradient tracking.

    Returns:
        ``(loss_dict, all_dists)``. ``loss_dict`` has the keys
        ``'cluster_distance_mapping_loss'`` (mean over batches of the batch-mean
        mapping loss) and ``'sequence_wise_cluster_loss'`` (mean over batches of
        the batch-summed clustering loss). ``all_dists`` concatenates every
        batch's distances on the CPU along dim 0.
    """
    mapping_total = 0
    cluster_total = 0
    dist_batches = []

    model.eval()
    progress = tqdm(data_loader, total=len(data_loader), desc="Validation")

    with torch.no_grad():
        for batch, _ in progress:
            batch = batch.to(config['device'])

            z, p, q, c, threshold = model(batch)
            cluster_loss = custom_loss_function(q, p, threshold, config)
            dist, scores, mapping_loss = mapping_loss_function(z, c, R, config)

            mapping_total = mapping_total + mapping_loss.mean()
            cluster_total = cluster_total + cluster_loss.sum()

            dist_batches.append(dist.detach().cpu())

    num_batches = len(data_loader)
    loss_dict = {
        'cluster_distance_mapping_loss': mapping_total / num_batches,
        'sequence_wise_cluster_loss': cluster_total / num_batches,
    }

    return loss_dict, torch.cat(dist_batches, dim=0)

def calculate_anomaly_scores(loader, model, config, R):
    """Score every position of every sample yielded by ``loader``.

    The score is the mapping score plus the clustering loss, summed over the
    cluster axis. Returns a CPU tensor of shape
    [num_samples, sequence_size, 1] (batches concatenated along dim 0).
    """
    model.eval()
    batch_scores = []

    with torch.no_grad():
        for batch, _ in tqdm(loader, desc="Calculating scores"):
            batch = batch.to(config['device'])
            z, p, q, c, threshold = model(batch)
            cluster_loss = custom_loss_function(q, p, threshold, config)
            dist, scores, mapping_loss = mapping_loss_function(z, c, R, config)

            # sum over clusters -> [batch_size, sequence_size] -> [batch_size, sequence_size, 1]
            anomaly_score = (scores + cluster_loss).sum(dim=-1).unsqueeze(2)
            batch_scores.append(anomaly_score.detach().cpu())

    return torch.cat(batch_scores, dim=0)

def adjust_predictions(gt, pred):
    """Apply point adjustment to ``pred`` using the ground truth ``gt``.

    A segment is a maximal run of positions whose ground truth is not 0.
    If any position with ``gt == 1`` inside a segment is predicted as 1, the
    whole segment is marked as 1 in the result. Predictions outside such
    segments are left unchanged.

    Args:
        gt: Sequence of ground-truth labels (0 = normal).
        pred: Sequence of predictions, same length as ``gt``; must provide
            ``.copy()`` (list or NumPy array). It is not modified.

    Returns:
        A copy of ``pred`` with detected segments filled in.
    """
    adjusted_pred = pred.copy()
    total = len(gt)
    in_detected_segment = False

    for i in range(total):
        if gt[i] == 0:
            in_detected_segment = False
        elif gt[i] == 1 and pred[i] == 1 and not in_detected_segment:
            in_detected_segment = True

            # Walk back to the first position of the segment. The bound is
            # index 0 inclusive, so a segment starting at the very first
            # sample is covered as well.
            start = i
            while start > 0 and gt[start - 1] != 0:
                start -= 1

            # Walk forward to one past the last position of the segment.
            stop = i
            while stop < total and gt[stop] != 0:
                stop += 1

            for j in range(start, stop):
                adjusted_pred[j] = 1

    return adjusted_pred

def evaluate_predictions(gt, pred, label=""):
    """Print accuracy, precision, recall, F-score and confusion counts.

    Args:
        gt: Ground-truth binary labels (0 = normal, 1 = anomaly).
        pred: Predicted binary labels, same length as ``gt``.
        label: Heading printed in front of "Results:".
    """
    accuracy = accuracy_score(gt, pred)
    precision, recall, f_score, _ = precision_recall_fscore_support(gt, pred, average='binary')
    # Pin the label set so the matrix is always 2x2; without it a run where
    # only one class occurs yields a 1x1 matrix and the unpacking fails.
    tn, fp, fn, tp = confusion_matrix(gt, pred, labels=[0, 1]).ravel()
    print()
    print(f"{label} Results:")
    print(f"Accuracy: {accuracy*100:.2f}%, Precision: {precision*100:.2f}%, "
          f"Recall: {recall*100:.2f}%, F-score: {f_score*100:.2f}%")
    print(f"TN: {tn}, FP: {fp}, FN: {fn}, TP: {tp}\n")

def testing(config, train_loader, thre_loader, model, R):
    """Derive a global anomaly cut-off and report point-adjusted metrics.

    The cut-off is the ``100 - config['anormly_ratio']`` percentile of the
    anomaly scores of ``train_loader`` and ``thre_loader`` combined. The
    samples of ``thre_loader`` are then scored again together with their
    labels, thresholded, point-adjusted and evaluated. Nothing is returned;
    results are printed.

    Args:
        config: Mapping with ``'device'`` and ``'anormly_ratio'`` (plus
            whatever the loss functions read).
        train_loader: Loader of ``(x, y)`` batches used for the cut-off only.
        thre_loader: Loader of ``(x, y)`` batches used for the cut-off and
            for the evaluation.
        model: Network returning ``(z, p, q, c, threshold)``.
        R: Tensor of per-cluster radii.
    """
    # Scores used to fit the cut-off.
    train_scores = calculate_anomaly_scores(train_loader, model, config, R)
    thre_scores = calculate_anomaly_scores(thre_loader, model, config, R)

    # A single global cut-off over all positions of both sets.
    combined_scores = torch.cat([train_scores, thre_scores], dim=0).reshape(-1)
    anomaly_thresholds = np.percentile(combined_scores, 100-config['anormly_ratio'])
    print(f'Anomaly Thresholds: {anomaly_thresholds}')

    test_scores = []
    test_labels = []

    # Scores and labels are collected in the same pass so they stay aligned
    # even if the loader does not yield batches in a fixed order.
    model.eval()
    with torch.no_grad():
        for x, y in tqdm(thre_loader, desc="Testing"):
            x = x.to(config['device'])
            z, p, q, c, threshold = model(x)
            AnoDEC_loss = custom_loss_function(q, p, threshold, config)
            dist, scores, mapping_loss = mapping_loss_function(z, c, R, config)

            # sum over clusters -> [batch_size, sequence_size] -> [batch_size, sequence_size, 1]
            anomaly_score = torch.sum(scores + AnoDEC_loss, dim=-1).unsqueeze(2)

            test_scores.append(anomaly_score.detach().cpu())
            # Labels are only needed on the host; no device round trip.
            test_labels.append(y.detach().cpu())

    # flatten to 1-D NumPy arrays
    test_scores = np.array(torch.cat(test_scores, dim=0)).reshape(-1)
    test_labels = np.array(torch.cat(test_labels, dim=0)).reshape(-1)

    pred = (test_scores > anomaly_thresholds)
    gt = test_labels

    adjusted_pred = adjust_predictions(gt, pred)

    evaluate_predictions(gt, adjusted_pred, "Aftre adjusting")
    print(f'Cut-off: {anomaly_thresholds.item():.5f}, Radius: {[f"{r:.5f}" for r in R.tolist()]}\n')

    
def main_trainer(model, scheduler, optimizer, optimizer2, early_stopping_loss, config, train_loader, valid_loader, test_loader, thre_loader, R):
    """Drive training or testing depending on ``config['mode']``.

    ``'Train'``: for each of ``config['num_epochs']`` epochs run one training
    epoch, step the scheduler, validate, recompute the per-cluster radius
    from the train + validation distances, and write a checkpoint
    (``{'R', 'net_dict'}`` to ``<model_save_name>.pt``) whenever the
    validation loss improves by at least ``config['min_delta']``. The loop
    stops early when ``early_stopping_loss.step(...)`` returns a truthy value.

    ``'Test'``: delegate to ``testing``.

    ``test_loader`` is accepted but not used by this function.

    Returns:
        The per-epoch history dict (lists keyed by ``'train_total_loss'``,
        ``'train_mapping_loss'``, ``'train_cluster_loss'``,
        ``'valid_total_loss'``, ``'valid_mapping_loss'``,
        ``'valid_cluster_loss'``, ``'radius'``, ``'threshold'``). The lists
        are empty when no training epoch ran.
    """
    scaler = GradScaler()

    results = {
        'train_total_loss': [], 'train_mapping_loss': [], 'train_cluster_loss': [],
        'valid_total_loss': [], 'valid_mapping_loss': [], 'valid_cluster_loss': [],
        'radius': [], 'threshold': []
    }

    if config['mode'] == 'Train':
        best_loss = np.inf

        for epoch in range(config['num_epochs']):
            epoch_num = epoch + 1

            model, threshold, avg_mapping_loss, avg_cluster_loss, avg_total_loss, train_dists = training(config, train_loader, model, optimizer, optimizer2, epoch_num, R, scaler)
            scheduler.step()

            VALID_loss_dict, valid_dists = validation(config, valid_loader, model, R)
            valid_total_loss = VALID_loss_dict['cluster_distance_mapping_loss'] + VALID_loss_dict['sequence_wise_cluster_loss']

            # recompute the radius of each cluster from this epoch's distances
            all_dists = torch.cat([train_dists, valid_dists], dim=0)
            R = torch.tensor(get_radius(all_dists, config['nu']), device=config['device'])

            # History is stored as plain Python numbers/lists so it does not
            # keep device tensors alive and can be serialised by the caller.
            results['train_total_loss'].append(avg_total_loss)
            results['train_mapping_loss'].append(avg_mapping_loss)
            results['train_cluster_loss'].append(avg_cluster_loss)
            results['valid_total_loss'].append(float(valid_total_loss))
            results['valid_mapping_loss'].append(float(VALID_loss_dict['cluster_distance_mapping_loss']))
            results['valid_cluster_loss'].append(float(VALID_loss_dict['sequence_wise_cluster_loss']))
            results['radius'].append(R.tolist())
            results['threshold'].append(threshold.tolist())  # store the threshold as a plain list

            print(f'\nEpoch {epoch_num} Summary:')
            print(f'Train - Total Loss: {avg_total_loss:.5f}, Mapping Loss: {avg_mapping_loss:.5f}, Clustering Loss: {avg_cluster_loss:.5f}')
            print(f'Valid - Total Loss: {valid_total_loss:.5f}, '
                  f'Mapping Loss: {VALID_loss_dict["cluster_distance_mapping_loss"]:.5f}, '
                  f'Clustering Loss: {VALID_loss_dict["sequence_wise_cluster_loss"]:.5f}')
            print(f'Threshold: {[f"{t:.5f}" for t in threshold.tolist()]}, Radius: {[f"{r:.5f}" for r in R.tolist()]}\n')

            if valid_total_loss < best_loss:
                improvement = best_loss - valid_total_loss
                if improvement >= config['min_delta']:
                    best_loss = valid_total_loss
                    best_idx = epoch_num

                    # Save the inner network's weights when running on several
                    # GPUs, but only if the model really is wrapped; an
                    # unwrapped model on a multi-GPU host has no `.module`.
                    is_wrapped = torch.cuda.device_count() > 1 and hasattr(model, 'module')
                    network = model.module if is_wrapped else model
                    best_model_wts = copy.deepcopy(network.state_dict())

                    # The live model and R already hold exactly what is being
                    # written, so the checkpoint is not read back afterwards.
                    torch.save({'R': R, 'net_dict': best_model_wts}, f'{config["model_save_name"]}.pt')

                    print(f'==> best model saved {best_idx} epoch / loss : {valid_total_loss:.8f}')
                else:
                    print(f'Loss improved by {improvement:.8f}, but less than min_delta ({config["min_delta"]}). Not saving model.')
            else:
                print(f'Loss did not improve. Current: {valid_total_loss:.8f}, Best: {best_loss:.8f}')

            if early_stopping_loss.step(torch.tensor(valid_total_loss)):
                print("Early stopping")
                break

    if config['mode'] == 'Test':
        testing(config, train_loader, thre_loader, model, R)

    return results
